# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test TPU-specific extensions to pallas print call."""

import functools
import re
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax._src.pallas import pallas_test_util as ptu
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np

# Have JAX pick up its configuration flags via absl flag parsing (as the
# function name indicates); done once at import time, before any test runs.
jax.config.parse_flags_with_absl()

# Short alias for the partition-spec type.
# NOTE(review): not referenced anywhere in the visible part of this file.
P = jax.sharding.PartitionSpec

# Short alias for functools.partial.
# NOTE(review): the tests visible in this file spell out `functools.partial`
# rather than using this alias.
partial = functools.partial


@jtu.thread_unsafe_test_class()  # debug print test is not thread safe
class PallasCallPrintTest(ptu.PallasTPUTest):
  """Checks that `pl.debug_print` text from Pallas kernels shows up on stderr.

  Every test builds a kernel with `self.pallas_call`, runs it through
  `_run_and_capture_stderr`, and asserts on the captured text.
  """

  def _run_and_capture_stderr(self, kernel, x):
    """Compiles `kernel` for `x`, runs it once, and returns captured stderr.

    The kernel is jitted, lowered for `x`, and compiled with the
    `xla_tpu_enable_log_recorder` compiler option set to 'true'; all tests in
    this class expect the debug-print text on stderr under that option.

    Args:
      kernel: Callable taking one array argument (here always a function
        wrapped with `self.pallas_call`).
      x: Input array, used both for lowering and for the actual call.

    Returns:
      The string produced by `jtu.capture_stderr` for the duration of the
      kernel execution.
    """
    compiled_kernel = (
        jax.jit(kernel)
        .lower(x)
        .compile({'xla_tpu_enable_log_recorder': 'true'})
    )
    with jtu.capture_stderr() as get_output:
      # Block until the result is ready so the kernel has finished running
      # before the captured output is read.
      jax.block_until_ready(compiled_kernel(x))
    return get_output()

  def test_debug_print(self):
    """A constant message printed from the kernel body appears on stderr."""

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print('It works!')

    x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128))
    output = self._run_and_capture_stderr(kernel, x)
    self.assertIn('It works!', output)

  def test_debug_print_in_index_map(self):
    """A message printed from a BlockSpec index map appears on stderr."""

    def index_map(i):
      pl.debug_print('It works!')
      return (i, 0)

    @functools.partial(
        self.pallas_call,
        grid=(1,),
        in_specs=(pl.BlockSpec(index_map=index_map),),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]

    x = jnp.arange(8 * 128, dtype=jnp.float32).reshape((8, 128))
    output = self._run_and_capture_stderr(kernel, x)
    self.assertIn('It works!', output)

  @parameterized.product(dtype=[jnp.int32, jnp.float32])
  def test_debug_print_with_values(self, dtype):
    """Scalar values read from an SMEM input are printed with the message.

    The int32 case uses `{}` placeholders; the float32 case passes the value
    with no placeholder and expects it rendered as `f32[] 42`.
    """

    @functools.partial(
        self.pallas_call,
        in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    )
    def kernel(x_ref, o_ref):
      if dtype == jnp.int32:
        pl.debug_print('BEGIN1 x[0] == {}', x_ref[0])
        pl.debug_print(
            'BEGIN2 x[0] == {} ; x[1] == {} ; END', x_ref[0], x_ref[1]
        )
      else:
        # No `{}` placeholder here; the assertion below expects the value to
        # be appended after the message.
        pl.debug_print('BEGIN1 x[0] == ', x_ref[0])

    x = jnp.array([42, 24], dtype=dtype)
    output = self._run_and_capture_stderr(kernel, x)
    if dtype == jnp.int32:
      self.assertIn('BEGIN1 x[0] == 42', output)
      self.assertIn('BEGIN2 x[0] == 42 ; x[1] == 24 ; END', output)
    else:
      self.assertIn('BEGIN1 x[0] == f32[] 42', output)

  @parameterized.named_parameters(
      (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype)
      for shape in (
          (2, 8, 128),
          # test unaligned shapes
          (3,),
          (3, 4),
          (2, 3, 4),
          (2, 9, 129),
      )
      for dtype in (jnp.int32, jnp.uint32, jnp.float32)
  )
  def test_debug_print_vector(self, shape, dtype):
    """Printing a whole array emits its elements in `arange` order.

    The input is `arange(prod(shape))` reshaped to `shape`; the digit runs
    found after a `{` on each stderr line must be exactly 0, 1, ..., n - 1.
    """

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct(shape, dtype),
    )
    def kernel(x_ref, o_ref):
      pl.debug_print('{}', x_ref[...])
      o_ref[...] = x_ref[...]

    # Plain Python int so it can be used directly for lengths and ranges.
    n = int(np.prod(shape))
    x = jnp.arange(n, dtype=dtype).reshape(shape)
    output = self._run_and_capture_stderr(kernel, x)
    numbers = [
        int(num)
        for line in output.splitlines()
        if (match := re.search(r'\{(.*)', line))  # extract contents after `{`
        for num in re.findall(r'\d+', match.group(1))
    ]
    # Check if the numbers in the output match the values generated by `arange`.
    self.assertLen(numbers, n)
    # assertEqual (rather than assertTrue(all(...))) so that a failure shows
    # which printed values differ from the expected sequence.
    self.assertEqual(numbers, list(range(n)))


# When executed as a script, run the tests through absltest using the test
# loader provided by JAX's test utilities.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
